Maximum sum BST in binary tree [DFS+Stack, DFS+Recursion]
Time: O(N); Space: O(H); hard
Given the root of a binary tree, return the maximum sum of all keys of any subtree that is also a Binary Search Tree (BST). The empty tree counts as a BST with sum 0, so the answer is never negative.
Assume a BST is defined as follows:
The left subtree of a node contains only nodes with keys less than the node’s key.
The right subtree of a node contains only nodes with keys greater than the node’s key.
Both the left and right subtrees must also be binary search trees.
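The definition above translates directly into a brute-force reference: check every subtree against strictly narrowing (low, high) bounds and keep the best sum. This is a minimal sketch for cross-checking small inputs, not one of the notebook's solutions. It can take O(N^2) time on a skewed tree. The names `is_bst`, `subtree_sum` and `max_sum_bst_bruteforce` are hypothetical, and the `TreeNode` here adds optional child arguments for brevity.

```python
from typing import Optional


class TreeNode:
    """Minimal node; the optional child arguments are an assumption for brevity."""

    def __init__(self, x, left=None, right=None):
        self.val = x
        self.left = left
        self.right = right


def is_bst(node: Optional[TreeNode], lo=float("-inf"), hi=float("inf")) -> bool:
    """Every key must lie strictly between the bounds inherited from its ancestors."""
    if node is None:
        return True  # the empty tree is a valid BST
    if not lo < node.val < hi:
        return False
    # Keys on the left must stay below node.val, keys on the right above it.
    return is_bst(node.left, lo, node.val) and is_bst(node.right, node.val, hi)


def subtree_sum(node: Optional[TreeNode]) -> int:
    if node is None:
        return 0
    return node.val + subtree_sum(node.left) + subtree_sum(node.right)


def max_sum_bst_bruteforce(root: Optional[TreeNode]) -> int:
    """Hypothetical reference: test every subtree independently."""
    best = 0  # the empty BST contributes a sum of 0
    stack = [root]
    while stack:
        node = stack.pop()
        if node is None:
            continue
        if is_bst(node):
            best = max(best, subtree_sum(node))
        stack.extend((node.left, node.right))
    return best


# Example 5 below: [5,4,8,3,null,6,3]
root = TreeNode(5, TreeNode(4, TreeNode(3)), TreeNode(8, TreeNode(6), TreeNode(3)))
print(max_sum_bst_bruteforce(root))  # 7
```

Its only purpose is to be obviously correct, so a faster solution can be compared against it on small random trees.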
Example 1:
Input: root = {TreeNode} [1,4,3,2,4,2,5,null,null,null,null,null,null,4,6]
Output: 20
Explanation: The maximum sum in a valid BST is obtained by the subtree rooted at the node with key 3 (3 + 2 + 5 + 4 + 6 = 20).
Example 2:
Input: root = {TreeNode} [4,3,null,1,2]
Output: 2
Explanation: The maximum sum in a valid BST is obtained by the single-node subtree with key 2.
Example 3:
Input: root = {TreeNode} [-4,-2,-5]
Output: 0
Explanation: All values are negative, so every non-empty BST has a negative sum. The best choice is the empty BST, whose sum is 0.
Example 4:
Input: root = {TreeNode} [2,1,3]
Output: 6
Example 5:
Input: root = {TreeNode} [5,4,8,3,null,6,3]
Output: 7
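The `{TreeNode} [...]` inputs appear to use LeetCode's level-order encoding, where `null` marks a missing child and children are listed only for non-null nodes; this matches the trees built by hand in the test cells further down. A minimal sketch of a builder under that assumption (`build_tree` is a hypothetical helper, not part of the notebook):

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None stands for `null`.

    Assumes children are listed only for non-null nodes and that
    trailing nulls may be omitted.
    """
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next value is the left child, the one after it the right child.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


# Example 1's input
root = build_tree([1, 4, 3, 2, 4, 2, 5, None, None, None, None, None, None, 4, 6])
print(root.right.right.left.val, root.right.right.right.val)  # 4 6
```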
Constraints:
Each tree has at most 40000 nodes.
Each node’s value is in the range [-4 * 10^4, 4 * 10^4].
Hints:
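Create a data structure with 4 parameters for each subtree: (sum, isBST, maxKey, minKey), i.e. the sum of its keys, whether it is a BST, and its largest and smallest keys.
At each node, compute these parameters from its children's values, following the conditions of a Binary Search Tree: the node's key must be greater than the left subtree's maximum and less than the right subtree's minimum.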
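The hint's four-parameter record can be written as a small named tuple plus a single combine step. This is the same bookkeeping both solutions below perform with plain lists and tuples. A sketch under the assumption that the parameters are (isBST, sum, min key, max key) of a subtree; `SubtreeInfo`, `combine`, `EMPTY` and `INVALID` are hypothetical names.

```python
from typing import NamedTuple

INF = float("inf")


class SubtreeInfo(NamedTuple):
    """Hypothetical record for the hint's four parameters."""
    is_bst: bool
    total: int   # sum of keys (only meaningful when is_bst is True)
    lo: float    # smallest key in the subtree (+inf for an empty subtree)
    hi: float    # largest key in the subtree (-inf for an empty subtree)


EMPTY = SubtreeInfo(True, 0, INF, -INF)  # the empty tree is a BST with sum 0
INVALID = SubtreeInfo(False, 0, 0, 0)    # bounds are never read when is_bst is False


def combine(left: SubtreeInfo, val: int, right: SubtreeInfo) -> SubtreeInfo:
    """Compute a node's record from its children's records."""
    # A BST root needs BST children and max(left) < val < min(right).
    if left.is_bst and right.is_bst and left.hi < val < right.lo:
        return SubtreeInfo(True, left.total + val + right.total,
                           min(left.lo, val), max(val, right.hi))
    return INVALID


leaf1 = combine(EMPTY, 1, EMPTY)
leaf3 = combine(EMPTY, 3, EMPTY)
print(combine(leaf1, 2, leaf3))         # SubtreeInfo(is_bst=True, total=6, lo=1, hi=3)
print(combine(leaf3, 2, leaf1).is_bst)  # False
```

The infinite bounds of `EMPTY` are what let a leaf pass the ordering check without special-casing missing children.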
```python
class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None
```
1. DFS solution with stack
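```python
class Solution1(object):
    """
    Time: O(N)
    Space: O(H)
    """
    def maxSumBST(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """
        result = 0
        # Each stack frame is [node, tmp, ret]:
        #   tmp is None on the first visit; on the second visit it holds the
        #   children's results [left_info, right_info].
        #   ret is the list this node fills with
        #   [is_bst, subtree_sum, min_key, max_key] for its parent.
        stk = [[root, None, []]]
        while stk:
            node, tmp, ret = stk.pop()
            if tmp:
                # Second visit: both children have been processed.
                lvalid, lsum, lmin, lmax = tmp[0]
                rvalid, rsum, rmin, rmax = tmp[1]
                if lvalid and rvalid and lmax < node.val < rmin:
                    total = lsum + node.val + rsum
                    result = max(result, total)
                    ret[:] = [True, total, min(lmin, node.val), max(node.val, rmax)]
                    continue
                ret[:] = [False, 0, 0, 0]
                continue
            if not node:
                # An empty subtree is a BST with sum 0; the infinite bounds
                # let any parent key pass the ordering check.
                ret[:] = [True, 0, float("inf"), float("-inf")]
                continue
            # First visit: revisit this node after its left and right children.
            new_tmp = [[], []]
            stk.append([node, new_tmp, ret])
            stk.append([node.right, None, new_tmp[1]])
            stk.append([node.left, None, new_tmp[0]])
        return result
```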
```python
s = Solution1()

root = TreeNode(1)
root.left = TreeNode(4)
root.right = TreeNode(3)
root.left.left = TreeNode(2)
root.left.right = TreeNode(4)
root.right.left = TreeNode(2)
root.right.right = TreeNode(5)
root.right.right.left = TreeNode(4)
root.right.right.right = TreeNode(6)
assert s.maxSumBST(root) == 20

root = TreeNode(4)
root.left = TreeNode(3)
root.left.left = TreeNode(1)
root.left.right = TreeNode(2)
assert s.maxSumBST(root) == 2

root = TreeNode(-4)
root.left = TreeNode(-2)
root.right = TreeNode(-5)
assert s.maxSumBST(root) == 0

root = TreeNode(2)
root.left = TreeNode(1)
root.right = TreeNode(3)
assert s.maxSumBST(root) == 6

root = TreeNode(5)
root.left = TreeNode(4)
root.right = TreeNode(8)
root.left.left = TreeNode(3)
root.right.left = TreeNode(6)
root.right.right = TreeNode(3)
assert s.maxSumBST(root) == 7
```
2. DFS solution with recursion
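```python
class Solution2(object):
    """
    Time: O(N)
    Space: O(H) for the recursion stack
    """
    def maxSumBST(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """
        # Recursion depth grows with tree height, so a skewed tree with tens of
        # thousands of nodes can exceed CPython's default recursion limit (1000)
        # and raise RecursionError; the stack-based Solution1 avoids this.
        def dfs(node, result):
            # Returns (is_bst, subtree_sum, min_key, max_key) for node's subtree.
            if not node:
                return True, 0, float("inf"), float("-inf")
            lvalid, lsum, lmin, lmax = dfs(node.left, result)
            rvalid, rsum, rmin, rmax = dfs(node.right, result)
            if lvalid and rvalid and lmax < node.val < rmin:
                total = lsum + node.val + rsum
                # result is a one-element list so that dfs can update it in place.
                result[0] = max(result[0], total)
                return True, total, min(lmin, node.val), max(node.val, rmax)
            return False, 0, 0, 0

        result = [0]
        dfs(root, result)
        return result[0]
```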
```python
s = Solution2()

root = TreeNode(1)
root.left = TreeNode(4)
root.right = TreeNode(3)
root.left.left = TreeNode(2)
root.left.right = TreeNode(4)
root.right.left = TreeNode(2)
root.right.right = TreeNode(5)
root.right.right.left = TreeNode(4)
root.right.right.right = TreeNode(6)
assert s.maxSumBST(root) == 20

root = TreeNode(4)
root.left = TreeNode(3)
root.left.left = TreeNode(1)
root.left.right = TreeNode(2)
assert s.maxSumBST(root) == 2

root = TreeNode(-4)
root.left = TreeNode(-2)
root.right = TreeNode(-5)
assert s.maxSumBST(root) == 0

root = TreeNode(2)
root.left = TreeNode(1)
root.right = TreeNode(3)
assert s.maxSumBST(root) == 6

root = TreeNode(5)
root.left = TreeNode(4)
root.right = TreeNode(8)
root.left.left = TreeNode(3)
root.right.left = TreeNode(6)
root.right.right = TreeNode(3)
assert s.maxSumBST(root) == 7
```
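Because the constraints allow up to 40000 nodes, a skewed tree can be far deeper than CPython's default recursion limit of 1000, so Solution2 can raise RecursionError on such input. As an alternative to Solution1's explicit frames, one can collect the nodes in preorder and then process them in reverse, which guarantees every child is finished before its parent. This is a hedged sketch, not the notebook's method; `max_sum_bst_iterative` is a hypothetical name.

```python
from typing import Optional


class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def max_sum_bst_iterative(root: Optional[TreeNode]) -> int:
    """Hypothetical two-pass variant: preorder collection, then reverse processing."""
    if root is None:
        return 0
    # Pass 1: preorder; each node is appended before any of its descendants.
    order = []
    stack = [root]
    while stack:
        node = stack.pop()
        order.append(node)
        if node.left is not None:
            stack.append(node.left)
        if node.right is not None:
            stack.append(node.right)
    # Pass 2: walking backwards visits children before their parents.
    empty = (True, 0, float("inf"), float("-inf"))
    info = {}  # node -> (is_bst, total, min_key, max_key); nodes hash by identity
    best = 0
    for node in reversed(order):
        lvalid, lsum, lmin, lmax = info[node.left] if node.left is not None else empty
        rvalid, rsum, rmin, rmax = info[node.right] if node.right is not None else empty
        if lvalid and rvalid and lmax < node.val < rmin:
            total = lsum + node.val + rsum
            best = max(best, total)
            info[node] = (True, total, min(lmin, node.val), max(node.val, rmax))
        else:
            info[node] = (False, 0, 0, 0)
    return best


# A right-skewed chain 1 -> 2 -> ... -> 40000 is a valid BST of height 40000.
root = TreeNode(1)
node = root
for v in range(2, 40001):
    node.right = TreeNode(v)
    node = node.right
print(max_sum_bst_iterative(root))  # 800020000
```

The trade-off is memory: the `info` dictionary keeps one record per node, so extra space is O(N), whereas Solution1 keeps only O(H) frames on its stack.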